from scipy import *
from numpy import *
from visual import *
from random import random
import time


class ElectronWave:
    """One plane-wave basis component for a single electron on a square lattice.

    The wave has lattice momentum of magnitude ``mode*2*pi/sizeOfLattice``.
    For ``direction == 0`` it points along x with sign ``(-1)**elec``; for
    ``direction == 1`` it points along +y. Its amplitude is
    ``tonorm * coeff * exp(-i p.(r - center))``.
    """

    def __init__(self, sizeOfLattice, center, elec, direction, mode, coeff):
        self.sizeOfLattice = sizeOfLattice
        # Phase reference point; must be indexable as center[0], center[1]
        # (a visual vector, tuple or array all work).
        self.center = center
        self.elec = elec
        self.direction = direction
        self.mode = mode
        self.coeff = coeff

        # Extra multiplicative factor on the amplitude; Basis.setNorms sets it.
        self.tonorm = 1.0

    def spinor(self, x, y):
        """Return the 4-component spinor at lattice site (x, y).

        The order is [s1 up, s1 down, s2 up, s2 down]. Both "up" components
        carry the same plane-wave amplitude and both "down" components are 0.

        Raises:
            ValueError: if ``self.direction`` is neither 0 nor 1.
        """
        k = self.mode * 2 * pi / self.sizeOfLattice
        if self.direction == 0:
            px, py = (-1) ** self.elec * k, 0.0
        elif self.direction == 1:
            px, py = 0.0, k
        else:
            # Previously this fell through and raised an UnboundLocalError on `p`.
            raise ValueError("direction must be 0 or 1, got %r" % (self.direction,))

        # p . (r - center); p has no z component, so only x and y contribute.
        # Uses this wave's own center rather than a module-level global.
        phase = px * (x - self.center[0]) + py * (y - self.center[1])
        amplitude = self.tonorm * self.coeff * e ** (-1j * phase)

        return [amplitude, 0, amplitude, 0]

    def setNorm(self, tonorm):
        """Set the multiplicative normalization factor applied in spinor()."""
        self.tonorm = tonorm



class Basis:
    """All ElectronWave components for two electrons.

    Components are stored as ``self.basis[elec][direction][mode]`` with
    ``elec`` in {0, 1}, ``direction`` in {0, 1} and ``mode`` in
    ``range(sizeOfBasis)``. Mode ``n`` is weighted by
    ``(1/sigma) * exp(-n**2 / sigma**2)``.
    """

    def __init__(self, sizeOfLattice, sizeOfBasis, center, sigma, Z):
        self.sizeOfLattice = sizeOfLattice
        self.sizeOfBasis = sizeOfBasis

        self.sigma = sigma
        self.center = center
        # Stored for callers; not used by any method in this class.
        self.Z = Z

        # float() keeps 1/sigma and mode**2/sigma**2 from truncating under
        # Python 2 integer division if an int sigma is passed.
        width = float(self.sigma)

        self.basis = []
        for elec in range(2):
            byDirection = []
            for direction in range(2):
                waves = []
                for mode in range(self.sizeOfBasis):
                    coeff = (1.0 / width) * e ** (-mode ** 2 / width ** 2)
                    print(coeff)
                    waves.append(ElectronWave(sizeOfLattice, center, elec,
                                              direction, mode, coeff))
                byDirection.append(waves)
            self.basis.append(byDirection)

    def normalizeElectrons(self):
        """Rescale each electron so its density summed over the lattice is 1.

        Safe to call repeatedly: the factors are reset to 1 before the norm is
        measured, so a second call does not undo the first.

        Raises:
            ValueError: if an electron's total density is zero (or NaN).
        """
        for elec in range(2):
            # Measure the unnormalized wavefunction; otherwise the sum would
            # already include the previous factor.
            self.setNorms(elec, 1.0)
            norm = 0.0
            for x in range(self.sizeOfLattice):
                for y in range(self.sizeOfLattice):
                    norm += self.getTotalDensity(elec, x, y)
            if not norm > 0:
                raise ValueError("electron %d has zero total density; "
                                 "cannot normalize" % elec)
            factor = 1 / sqrt(norm)
            self.setNorms(elec, factor)
            print("Normalization constant for electron: %s was: %s" % (elec, factor))

    def getTotalDensity(self, elec, x, y):
        """Return |psi|^2 at site (x, y) for electron ``elec``.

        psi is the component-wise sum of every basis spinor for that electron,
        so interference between modes is included. The result is the sum of
        |component|^2 over the four spinor components.
        """
        # Element-wise accumulation: the old `list += list` concatenated the
        # spinors instead of adding them.
        total = [0j, 0j, 0j, 0j]
        for direction in range(2):
            for mode in range(self.sizeOfBasis):
                wave = self.basis[elec][direction][mode]
                for i, component in enumerate(wave.spinor(x, y)):
                    total[i] += component

        density = 0.0
        for component in total:
            density += (component.conjugate() * component).real
        return density

    def setNorms(self, elec, tonorm):
        """Set the normalization factor on every basis component of ``elec``."""
        for direction in range(2):
            for mode in range(self.sizeOfBasis):
                self.basis[elec][direction][mode].setNorm(tonorm)



class DensityPoint:
    """A single VPython sphere marking the density at one lattice site."""

    def __init__(self):
        # 1 = shown, 0 = hidden; kept in sync with self.point.visible.
        self.visibility = 1
        # Starts tiny; setAttributes() assigns the real radius.
        self.point = sphere(visible=self.visibility, radius=.001)

    def setAttributes(self, position, size, color):
        """Move, resize and recolor the marker sphere."""
        self.point.pos = position
        self.point.radius = size
        self.point.color = color

    def toggleVisibility(self):
        """Flip between visible (1) and hidden (0) and apply it to the sphere."""
        self.visibility = (1 + self.visibility) % 2
        # Previously only the flag changed and the sphere stayed as it was.
        self.point.visible = self.visibility


class Lattice:
    """Grid of DensityPoint markers sized by electron 0's density.

    ``pointLattice[x][y]`` holds the marker for site (x, y). Its radius is the
    site's density multiplied by ``sizeOfLattice``, and all markers are white.
    Each density value is printed as it is computed.
    """

    def __init__(self, basis,sizeOfLattice, sizeOfBasis, center):
        self.sizeOfLattice = sizeOfLattice
        self.sizeOfBasis = sizeOfBasis
        self.center = center

        radiusScale = sizeOfLattice
        white = (1.0, 1.0, 1.0)

        self.pointLattice = []
        for col in range(sizeOfLattice):
            column = []
            for row in range(sizeOfLattice):
                siteDensity = basis.getTotalDensity(0, col, row)
                print(siteDensity)
                marker = DensityPoint()
                marker.setAttributes(vector(col, row, 0.0), radiusScale * siteDensity, white)
                column.append(marker)
            self.pointLattice.append(column)



# --- Simulation parameters ---------------------------------------------------
# The lattice is sizeOfLattice x sizeOfLattice sites. Each electron and
# direction gets sizeOfBasis plane-wave modes (see Basis).
sizeOfLattice = 16
sizeOfBasis = 8

# NOTE(review): m and c are not referenced anywhere in the code shown here.
# They are presumably placeholders for mass and speed of light in natural
# units; confirm before relying on them.
m = 1.0
c = 1.0

# Lattice midpoint, used as the reference point for the spinor phases.
center = vector((sizeOfLattice/2.0),(sizeOfLattice/2.0),0)
# Passed to Basis, which stores it but does not otherwise use it.
Z = 2.0
# Width of the Gaussian weighting over mode number used in Basis.__init__.
sigma = sizeOfLattice/6.0



# 600x600 VPython window with a black background, centered on `center`.
scene = display(title='Density Plot', x=0, y=0, width=600, height=600, center=center, background=(0,0,0))



# Build the basis; Basis.__init__ prints each mode coefficient as it goes.
basis = Basis(sizeOfLattice, sizeOfBasis, center, sigma, Z)
# Normalization, the density dump and the lattice plot are currently disabled.
# Uncomment them to normalize the electrons and draw electron 0's density.
##basis.normalizeElectrons()

##for x in xrange(sizeOfLattice):
##    for y in range(sizeOfLattice):
##        print basis.getTotalDensity(0,x,y)
##
##
##lattice = Lattice(basis, sizeOfLattice, sizeOfBasis, center)
